import json
import os
import re
from collections import defaultdict

from openai import OpenAI

# Shared OpenAI-compatible client used by every LLM call in this script.
# The key and endpoint are taken from the environment when set, so secrets do
# not have to be written into this file; otherwise the "  " placeholders below
# are used and must be edited.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "  "),  # your api key
    base_url=os.environ.get("OPENAI_BASE_URL", "  "),  # the url of LLM API (gpt-3.5-turbo-ca) provider
)





def parse_entity_list(reply):
    """Turn an LLM's comma-separated entity reply into a set of lowercase names.

    Commas and newlines are both accepted as separators. Surrounding
    whitespace and trailing periods are stripped from each item, and empty
    items are dropped. A falsy reply (``None`` or ``""``) yields an empty set.
    """
    if not reply:
        return set()
    names = (item.strip().rstrip(".").strip() for item in re.split(r"[,\n]", reply.lower()))
    return {name for name in names if name}


def extract_entities(text):
    """Ask the LLM for the main entities (objects, animals, ...) mentioned in *text*.

    Returns:
        A set of lowercase entity names. On any error the message is printed
        and an empty set is returned, so one failed request does not abort a
        whole dataset pass.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo-ca",
            messages=[{
                "role": "user",
                "content": f"Please extract the main (minus adjectives) entities (objects, animals etc., but not to include any environment) \
                from the following text, separated by commas, without explanation:\n{text}"
            }],
            # Near-zero temperature for (as far as the API allows) repeatable output.
            temperature=0.00000001,
            top_p=1
        )
        return parse_entity_list(response.choices[0].message.content)
    except Exception as e:
        print(f"Error: {str(e)}")
        return set()


def load_data(file_path):
    """Load a JSON-lines result file and extract entities from each response.

    Each non-blank line is expected to be a JSON object with ``"image"`` and
    ``"response"`` keys. Blank lines are ignored. Lines that are not valid
    JSON, are not objects, or lack either key are reported and skipped.
    Entities are obtained with ``extract_entities`` (one LLM request per
    accepted line).

    Returns:
        A ``defaultdict(dict)`` mapping image name to
        ``{"response": ..., "entities": set}``. A later line for the same
        image overwrites an earlier one.
    """
    data = defaultdict(dict)
    count = 0  # total number of lines read, reported at the end
    with open(file_path, 'r', encoding='utf-8') as f:
        for count, line in enumerate(f, 1):
            if not line.strip():
                continue  # e.g. a trailing newline at EOF
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                print(f"Error in Line {count}, skip it.")
                continue
            if not isinstance(entry, dict) or not all(k in entry for k in ("image", "response")):
                print(f"Error in Line {count}, skip it.")
                continue

            record = data[entry["image"]]
            record["response"] = entry["response"]
            record["entities"] = extract_entities(entry["response"])

    print(f"\n {file_path} Done,  {count} piece of data in total.\n")
    return data


def calculate_entity_change(orig_entities, adv_entities):
    """Return the percentage of *orig_entities* that also appear in *adv_entities*.

    Despite the name this is a preservation score: 100.0 means every original
    entity is present in *adv_entities*, 0.0 means none is. Matching is exact
    membership (no synonym handling). Both arguments may be any iterables of
    hashable items; duplicates are ignored.

    Returns:
        A float in [0, 100]. It is 0.0 when *orig_entities* is empty, instead
        of dividing by zero.
    """
    orig = set(orig_entities)
    if not orig:
        return 0.0
    common = orig & set(adv_entities)
    return len(common) / len(orig) * 100


import json

def _load_dataset_info(dataset_info_path):
    """Map each image file name to its detected categories minus the attack target.

    The JSON file is read as an object keyed by image path. Each value is
    accessed as ``content['attack_target']['category_name']`` and
    ``content['detected_categories'][i]['category_name']``.
    """
    with open(dataset_info_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    processed = {}
    for image_path, content in data.items():
        image_name = image_path.split("/")[-1]
        target = content['attack_target']['category_name']
        categories = {category['category_name'] for category in content['detected_categories']}
        categories.discard(target)
        processed[image_name] = {"entities": categories}
    return processed


def parse_match_count(reply):
    """Extract the integer answer from the LLM's reply, or return None.

    Only the text before the first comma is considered. It must be a
    (possibly multi-digit) number, optionally followed by a period. Chatty
    replies are rejected rather than guessed at, since they can contain other
    numbers (e.g. "list 1").
    """
    if not reply:
        return None
    head = reply.strip().split(",")[0].strip().rstrip(".").strip()
    return int(head) if re.fullmatch(r"[0-9]+", head) else None


def _ask_shared_entity_count(orig_entities, adv_entities):
    """Ask the LLM how many entities in both collections share a meaning; return its raw reply."""
    # Sorted lists make the prompt independent of set iteration order, which
    # changes between interpreter runs (string hash randomization).
    list1 = sorted(orig_entities)
    list2 = sorted(adv_entities)
    # The backslash continuations embed each continuation line's leading spaces
    # into the prompt. They are kept as-is so the prompt text matches earlier runs.
    prompt = f"Here are two lists, each containing semantic entity segmentation results for the same image under two different scenarios. Please determine the number of entities with the same meaning between the two lists. \
                    Note that there may be synonyms, such as person and (people, man or woman), bird and birds, tv and screen, set and remote, etc. For example, if list 1 is [airplane, tv] and list 2 is [airplanes, apple], the answer would be 1. \
                    Also, once an entity in list 1 has been matched, it should not be matched again. Entities like cat and bear are not considered synonyms, tv and screen, person and man, woman, etc., can be understood as synonyms when considering inclusive or referring relationships. So please only match synonyms. \
                    Now, given list 1: {list1}, list 2: {list2}. Just output a single number as the result."
    # NOTE(review): extract_entities uses "gpt-3.5-turbo-ca"; confirm the different model here is intended.
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.00000001,
        top_p=1,
    )
    return response.choices[0].message.content


def main(file1_path, file2_path, dataset_info_path='  '):
    """Compare reference entities with attack-result entities and print the SPR.

    Args:
        file1_path: Currently unused; kept for call-site compatibility. The
            reference entities come from *dataset_info_path*.
        file2_path: JSON-lines attack results, loaded with ``load_data``.
        dataset_info_path: Dataset JSON file of detected categories and attack
            targets. The "  " default is a placeholder that must be replaced.

    For every reference image whose ``.jpg`` name (rewritten to ``.png``) is
    present in the attack results, the LLM counts reference entities with a
    synonym among the attack-result entities. The per-image SPR is that count
    as a percentage of the reference entities. Images with no reference
    entities, failed requests, or non-numeric replies are skipped. Per-image
    results and the average SPR are printed.
    """
    print("Load:dataset_info.json")
    data1 = _load_dataset_info(dataset_info_path)
    data2 = load_data(file2_path)

    results = []
    matched_count = 0

    print("\n Begin compare.")
    total_images = len(data1)
    for idx, image in enumerate(data1, 1):
        # Reference names use .jpg while attack-result keys use .png.
        image1 = image.replace(".jpg", ".png")
        if image1 not in data2:
            print(f"Skip mismatched image: {image} ({idx}/{total_images})")
            continue

        matched_count += 1
        orig_entities = data1[image]["entities"]
        adv_entities = data2[image1]["entities"]

        print(f"\n Process: {idx}/{total_images} | 图片: {image}")
        print("Origin:", orig_entities)
        print("Adv:", adv_entities)
        print("━" * 50)

        if not orig_entities:
            # SPR is a percentage of the reference entities; undefined when there are none.
            print(f"Skip image without reference entities: {image}")
            continue

        try:
            reply = _ask_shared_entity_count(orig_entities, adv_entities)
        except Exception as e:  # one failed request must not abort the whole run
            print(f"Error: {str(e)}")
            continue

        print((reply or "").strip())
        shared = parse_match_count(reply)
        if shared is None:
            continue

        # Each reference entity may be matched at most once, so any larger
        # count from the model is an over-count.
        shared = min(shared, len(orig_entities))
        change_percent = shared / len(orig_entities) * 100

        results.append({
            "image": image,
            "adversarial_response": data2[image1]["response"],
            "original_entities": sorted(orig_entities),
            "adversarial_entities": sorted(adv_entities),
            "change_percent": round(change_percent, 2)
        })

    # Output results
    print("Final result:")
    for result in results:
        print(f"\n Image: {result['image']}")
        print(f"Origin ({len(result['original_entities'])}): {result['original_entities']}")
        print(f"Adv ({len(result['adversarial_entities'])}): {result['adversarial_entities']}")
        print(f"SPR: {result['change_percent']}%")
        print("━" * 50)

    avg_change = sum(r["change_percent"] for r in results) / len(results) if results else 0
    print(f"Average SPR: {round(avg_change, 2)}%")
    print(f"Complete comparison, matching {matched_count}/{total_images} images")


if __name__ == "__main__":
    # Entry point: edit the placeholder paths below before running.
    clean_results_path = "  "  # your json file of clean images
    attack_results_path = "  "  # your json file of attack results
    main(clean_results_path, attack_results_path)